# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch
import torch.nn.functional as F

from megatron.core import tensor_parallel, parallel_state

from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl

from transformer_engine.pytorch.distributed import checkpoint, checkpointVirance
# HACK(huang.huang): recompute/variance for mlp in moe with fp8/bf16: 
# support mlp_rms_recompute,  which combine rms, mlp into one checkpoint;
# add new arg "no_recompute" to avoid repated recompute for sharedEXP while 
# moe_layer is already recomputed outsides
def MLP_forward(self, hidden_states, norm_func=None, no_recompute=False):
    """
    Perform the forward pass through the MLP block.
    Args:
    hidden_states (Tensor): Input tensor of shape [s, b, h] where s is sequence length,
        b is batch size, and h is hidden size.
    norm_func (function): whether to do layernorm inner MLP instead of transformerlayer.
    no_recompute (bool): default is False. only set to True when is sharedEXP, 
                        to avoid repeated recomputation between this mlp and moe_layer 
    """
    # [s, b, 4 * h/p]
    def custom_forward(hidden_states):
        if norm_func is not None:
            assert self.config.mlp_rms_recompute
            
            hidden_states= norm_func(hidden_states)
        intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

        if self.config.bias_activation_fusion:
            if self.activation_func == F.gelu:
                if self.config.gated_linear_unit:
                    intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
                else:
                    assert self.config.add_bias_linear is True
                    intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
            elif self.activation_func == F.silu and self.config.gated_linear_unit:
                intermediate_parallel = bias_swiglu_impl(
                    intermediate_parallel,
                    bias_parallel,
                    self.config.activation_func_fp8_input_store,
                )
            else:
                raise ValueError("Only support fusion of gelu and swiglu")
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            if self.config.gated_linear_unit:

                def glu(x):
                    x = torch.chunk(x, 2, dim=-1)
                    return self.config.activation_func(x[0]) * x[1]

                intermediate_parallel = glu(intermediate_parallel)
            else:
                intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.linear_fc2(intermediate_parallel)
        return output, output_bias
    
    if norm_func is not None:
        _custom_func_first = lambda x : self.custom_func_first(norm_func(x))
    else:
        _custom_func_first = lambda x : self.custom_func_first(x)# use lambda to create new func instead of method object which can't add new attribute
    if no_recompute: #avoid to recompute under another recompute context outside this function, like in sharedExp
        return custom_forward(hidden_states)
    
    if self.config.mlp_recompute:
        if self.config.fp8:
            if self.config.recompute_variance:
                output, output_bias = checkpointVirance(
                    _custom_func_first,
                    self.linear_fc2,
                    hidden_states,
                    distribute_saved_activations=self.config.distribute_saved_activations,
                    get_rng_state_tracker=tensor_parallel.random.get_cuda_rng_tracker,
                    tp_group=parallel_state.get_tensor_model_parallel_group(),
                )
            else:
                output, output_bias = checkpoint(
                    custom_forward, 
                    hidden_states,
                    distribute_saved_activations=self.config.distribute_saved_activations,
                    get_rng_state_tracker=tensor_parallel.random.get_cuda_rng_tracker,
                    tp_group=parallel_state.get_tensor_model_parallel_group(),
                    )
        else:
            if self.config.recompute_variance:
                output, output_bias = tensor_parallel.checkpointVirance(
                    _custom_func_first, self.linear_fc2, False, hidden_states)
            else:
                output, output_bias = tensor_parallel.checkpoint(
                    custom_forward, False, hidden_states)
    else:
        output, output_bias = custom_forward(hidden_states)
    return output, output_bias
## HACK(huang.huang)

# HACK(huang.huang): seperate linear1 and act from mlp, to support potential recoumpute variance,
# which need a separated linear2
def MLP_custom_func_first(self, hidden_states):
    intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

    if self.config.bias_activation_fusion:
        if self.activation_func == F.gelu:
            if self.config.gated_linear_unit:
                intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
            else:
                assert self.config.add_bias_linear is True
                intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        elif self.activation_func == F.silu and self.config.gated_linear_unit:
            intermediate_parallel = bias_swiglu_impl(
                intermediate_parallel,
                bias_parallel,
                self.config.activation_func_fp8_input_store,
            )
        else:
            raise ValueError("Only support fusion of gelu and swiglu")
    else:
        if bias_parallel is not None:
            intermediate_parallel = intermediate_parallel + bias_parallel
        if self.config.gated_linear_unit:

            def glu(x):
                x = torch.chunk(x, 2, dim=-1)
                return self.config.activation_func(x[0]) * x[1]

            intermediate_parallel = glu(intermediate_parallel)
        else:
            intermediate_parallel = self.activation_func(intermediate_parallel)
    
    return intermediate_parallel
## HACK(huang.huang)

from transformer_engine.musa.pytorch.utils import replace_attr, add_attr
from megatron.core.transformer.mlp import MLP
replace_attr(MLP,"forward", MLP_forward)
add_attr(MLP,"custom_func_first", MLP_custom_func_first)